#include "torch/csrc/jit/ir.h"

#include <algorithm>
#include <unordered_map>

#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/node_hashing.h"
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/hash.h"

namespace torch { namespace jit {

// Common subexpression elimination (CSE) over a single block.
//
// Walks the nodes of `block` once, remembering every expression seen so far in
// a set keyed by HashNode / EqualNode. When a node matches an expression that
// is already known -- either to an enclosing block (as reported by
// `parent_lookup_fn`) or to this block -- all of its uses are redirected to
// the earlier node and the duplicate is destroyed. Nodes that own sub-blocks
// are not deduplicated themselves; their sub-blocks are processed recursively.
//
// `parent_lookup_fn` maps a node to an equivalent node from an enclosing
// scope; a null result is treated as "no equivalent".
//
// Nodes are visited in topological order, which is why one pass is enough.
void EliminateCommonSubexpression(Block * block,
                                  std::function<Node*(Node*)> parent_lookup_fn) {
  std::unordered_set<Node*, HashNode, EqualNode> seen;

  // Lookup handed to nested blocks: expressions already recorded for this
  // block take priority, anything else is deferred to the enclosing scope.
  auto lookup_for_children = [&](Node* candidate) -> Node* {
    auto hit = seen.find(candidate);
    return hit != seen.end() ? *hit : parent_lookup_fn(candidate);
  };

  for (auto cursor = block->nodes().begin(); cursor != block->nodes().end(); ++cursor) {
    auto current = *cursor;

    // Not enough information is available to do CSE on PythonOp / Print
    // nodes, so they are left untouched.
    const auto kind = current->kind();
    if (kind == prim::PythonOp || kind == prim::Print) {
      continue;
    }

    // A node carrying sub-blocks is only descended into.
    if (!current->blocks().empty()) {
      for (auto sub_block : current->blocks()) {
        EliminateCommonSubexpression(sub_block, lookup_for_children);
      }
      continue;
    }

    // An equivalent expression in an enclosing block wins first.
    if (auto outer_match = parent_lookup_fn(current)) {
      current->replaceAllUsesWith(outer_match);
      cursor.destroyCurrent();
      continue;
    }

    // Otherwise try to record the node; a failed insertion means an equal
    // expression was already recorded for this block.
    auto inserted = seen.insert(current);
    if (inserted.second) {
      continue;
    }

    // Duplicate: redirect every use to the recorded node, then drop this one.
    current->replaceAllUsesWith(*inserted.first);
    cursor.destroyCurrent();
  }
}

// Graph-level entry point: runs common subexpression elimination on the
// block returned by `graph->block()`.
// NOTE(review): the call below passes no lookup function, so it relies on a
// default argument declared elsewhere (presumably in the pass header) --
// confirm.
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
  auto root_block = graph->block();
  EliminateCommonSubexpression(root_block);
}

}}
